import torch
import torch.nn as nn

from onpolicy.algorithms.utils.act import ACTLayer
from onpolicy.algorithms.utils.cnn import CNNBase
from onpolicy.algorithms.utils.gru import RNNLayer
from onpolicy.algorithms.utils.mlp import MLPBase_Actor, MLPBase_Critic, MLPBase_Trans
from onpolicy.algorithms.utils.popart import PopArt
from onpolicy.algorithms.utils.util import check, init
from onpolicy.utils.util import get_shape_from_obs_space


class R_Actor(nn.Module):
    """
    Actor network class for MAPPO. Outputs actions given observations.
    :param args: (argparse.Namespace) arguments containing relevant model information.
    :param obs_space: (gym.Space) observation space.
    :param action_space: (gym.Space) action space.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """

    def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
        super(R_Actor, self).__init__()
        self.hidden_size = args.hidden_size

        self._gain = args.gain
        self._use_orthogonal = args.use_orthogonal
        self._use_policy_active_masks = args.use_policy_active_masks
        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
        self._use_recurrent_policy = args.use_recurrent_policy
        self._recurrent_N = args.recurrent_N
        self.tpdv = dict(dtype=torch.float32, device=device)

        obs_shape = get_shape_from_obs_space(obs_space)
        base = CNNBase if len(obs_shape) == 3 else MLPBase_Actor
        self.base = base(args, obs_shape)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(
                self.hidden_size,
                self.hidden_size,
                self._recurrent_N,
                self._use_orthogonal,
            )

        self.act = ACTLayer(
            action_space, self.hidden_size, self._use_orthogonal, self._gain, args
        )
        # Projects the trans_dim stacked observation frames down to one frame.
        self.trans_dim = args.n_trans
        self.trans = MLPBase_Trans(args, self.trans_dim, obs_shape[0])

        self.to(device)
        self.algo = args.algorithm_name

    def forward(
        self, obs, rnn_states, masks, available_actions=None, deterministic=False
    ):
        """
        Compute actions from the given inputs.
        :param obs: (np.ndarray / torch.Tensor) observation inputs into network.
        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
        :param available_actions: (np.ndarray / torch.Tensor) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether to return the mode of the action distribution instead of sampling from it.

        :return actions: (torch.Tensor) actions to take.
        :return action_log_probs: (torch.Tensor) log probabilities of taken actions.
        :return rnn_states: (torch.Tensor) updated RNN hidden states.
        """
        obs = check(obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)

        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)

        if len(obs.shape) == 3:
            # Convert obs from [trans_dim, n_agent*n_rollout_threads, obs_dim]
            # to [1, n_agent*n_rollout_threads, obs_dim] via the trans layer.
            if obs.shape[0] == self.trans_dim:
                obs = obs.permute(1, 2, 0)
                obs = self.trans(obs)
                obs = obs.permute(2, 0, 1)
            # Finally flatten to [n_agent*n_rollout_threads, obs_dim].
            obs = obs.view(-1, obs.shape[-1])

        actor_features = self.base(obs)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

        actions, action_log_probs = self.act(
            actor_features, available_actions, deterministic
        )

        return actions, action_log_probs, rnn_states
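
    # Hedged shape note (an editorial sketch, not part of the original API):
    # in the standard MAPPO rollout, obs is expected as [batch, obs_dim] (or
    # [trans_dim, batch, obs_dim] for stacked inputs), rnn_states as
    # [batch, recurrent_N, hidden_size], and masks as [batch, 1], where
    # batch = n_agent * n_rollout_threads. Treat these shapes as assumptions
    # inferred from the reshaping logic above.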

    def evaluate_actions(
        self, obs, rnn_states, action, masks, available_actions=None, active_masks=None
    ):
        """
        Compute log probability and entropy of given actions.
        :param obs: (torch.Tensor) observation inputs into network.
        :param rnn_states: (torch.Tensor) if RNN network, hidden states for RNN.
        :param action: (torch.Tensor) actions whose entropy and log probability to evaluate.
        :param masks: (torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
        :param available_actions: (torch.Tensor) denotes which actions are available to agent
                                  (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.

        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """
        obs = check(obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        action = check(action).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)

        if active_masks is not None:
            active_masks = check(active_masks).to(**self.tpdv)
        if len(obs.shape) == 3:
            # Convert obs from [600, trans_dim, obs_dim] to [1, 600, obs_dim]
            # via the trans layer.
            if obs.shape[1] == self.trans_dim:
                obs = obs.permute(0, 2, 1)
                obs = self.trans(obs)
                obs = obs.permute(2, 0, 1)
            # Finally flatten to [n_agent*n_rollout_threads, obs_dim].
            obs = obs.view(-1, obs.shape[-1])

        actor_features = self.base(obs)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

        if self.algo == "hatrpo":
            action_log_probs, dist_entropy, action_mu, action_std, all_probs = (
                self.act.evaluate_actions_trpo(
                    actor_features,
                    action,
                    available_actions,
                    active_masks=active_masks
                    if self._use_policy_active_masks
                    else None,
                )
            )

            return action_log_probs, dist_entropy, action_mu, action_std, all_probs
        else:
            action_log_probs, dist_entropy = self.act.evaluate_actions(
                actor_features,
                action,
                available_actions,
                active_masks=active_masks if self._use_policy_active_masks else None,
            )

        return action_log_probs, dist_entropy
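

# --- Hedged illustration (editorial addition, not part of the original module) ---
# The permute/trans/view sequence in R_Actor collapses a stack of trans_dim
# observation frames into a single frame. Below is a self-contained sketch of
# that shape flow, using a plain nn.Linear(trans_dim, 1) as a stand-in for
# MLPBase_Trans (whose real architecture lives in onpolicy.algorithms.utils.mlp):
def _trans_shape_sketch():
    trans_dim, batch, obs_dim = 4, 6, 10          # assumed toy sizes
    obs = torch.randn(trans_dim, batch, obs_dim)  # stacked observation frames
    stand_in = nn.Linear(trans_dim, 1)            # stand-in for self.trans
    x = obs.permute(1, 2, 0)                      # [batch, obs_dim, trans_dim]
    x = stand_in(x)                               # [batch, obs_dim, 1]
    x = x.permute(2, 0, 1)                        # [1, batch, obs_dim]
    x = x.reshape(-1, x.shape[-1])                # [batch, obs_dim]
    assert x.shape == (batch, obs_dim)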


class R_Critic(nn.Module):
    """
    Critic network class for MAPPO. Outputs value function predictions given
    centralized input (MAPPO) or local observations (IPPO).
    :param args: (argparse.Namespace) arguments containing relevant model information.
    :param cent_obs_space: (gym.Space) (centralized) observation space.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """

    def __init__(self, args, cent_obs_space, device=torch.device("cpu")):
        super(R_Critic, self).__init__()
        self.hidden_size = args.hidden_size
        self._use_orthogonal = args.use_orthogonal
        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
        self._use_recurrent_policy = args.use_recurrent_policy
        self._recurrent_N = args.recurrent_N
        self._use_popart = args.use_popart
        self.tpdv = dict(dtype=torch.float32, device=device)
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][
            self._use_orthogonal
        ]

        cent_obs_shape = get_shape_from_obs_space(cent_obs_space)
        base = CNNBase if len(cent_obs_shape) == 3 else MLPBase_Critic
        self.base = base(args, cent_obs_shape)
        # self.trans_dim = args.n_trans
        # self.trans = MLPBase_Trans(args, self.trans_dim, cent_obs_shape[0])

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(
                self.hidden_size,
                self.hidden_size,
                self._recurrent_N,
                self._use_orthogonal,
            )

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0))

        if self._use_popart:
            self.v_out = init_(PopArt(self.hidden_size, 1, device=device))
        else:
            self.v_out = init_(nn.Linear(self.hidden_size, 1))

        self.to(device)

    def forward(self, cent_obs, rnn_states, masks):
        """
        Compute value function predictions from the given inputs.
        :param cent_obs: (np.ndarray / torch.Tensor) centralized observation inputs into network.
        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.

        :return values: (torch.Tensor) value function predictions.
        :return rnn_states: (torch.Tensor) updated RNN hidden states.
        """
        cent_obs = check(cent_obs).to(**self.tpdv)

        # Reading only a single layer of data here did not improve performance,
        # so the trans projection below is left disabled for the critic.
        # if len(cent_obs.shape) == 3:
        #     # Convert cent_obs from [trans_dim, n_agent*n_rollout_threads, obs_dim]
        #     # to [1, n_agent*n_rollout_threads, obs_dim] via the trans layer.
        #     if cent_obs.shape[0] == self.trans_dim:
        #         cent_obs = cent_obs.permute(1, 2, 0)
        #         cent_obs = self.trans(cent_obs)
        #         cent_obs = cent_obs.permute(2, 0, 1)
        #     # Finally flatten to [n_agent*n_rollout_threads, obs_dim].
        #     cent_obs = cent_obs.view(-1, cent_obs.shape[-1])

        # Alternative (also disabled): concatenate all trans_dim frames along
        # the feature axis before applying the trans layer.
        # if len(cent_obs.shape) == 3:
        #     obs_dim = cent_obs.shape[2]
        #     num_2 = cent_obs.shape[1]
        #     if cent_obs.shape[0] == self.trans_dim:
        #         cent_obs = (
        #             cent_obs.permute(1, 0, 2)
        #             .contiguous()
        #             .view(num_2, self.trans_dim * obs_dim)
        #         )
        #         cent_obs = self.trans(cent_obs)
        #         # Finally flatten to [n_agent*n_rollout_threads, obs_dim].
        #         cent_obs = cent_obs.view(-1, cent_obs.shape[-1])
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)

        critic_features = self.base(cent_obs)  # [n_agent*n_rollout_threads, hidden_size]
        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)
        values = self.v_out(critic_features)

        return values, rnn_states
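

# --- Hedged illustration (editorial addition, not part of the original module) ---
# A minimal sketch of what the init_(...) wrapper around v_out is assumed to do:
# apply the weight initializer selected by use_orthogonal and zero the bias,
# mirroring the common MAPPO `init` utility. The authoritative behavior lives
# in onpolicy.algorithms.utils.util.init.
def _value_head_init_sketch(hidden_size=64, use_orthogonal=True):
    head = nn.Linear(hidden_size, 1)
    init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
    init_method(head.weight)           # weight init chosen by use_orthogonal
    nn.init.constant_(head.bias, 0.0)  # the bias init lambda passed to init(...)
    return head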
